import json
import math
import os
from datetime import datetime
from typing import Optional, TypedDict

import numpy as np
import pandas as pd
import pingouin as pg
from scipy.spatial import distance

from config import Config, DataConfig

# Silence pandas' SettingWithCopyWarning (the default for this option is
# "warn"); this module assigns new columns to filtered DataFrame slices.
pd.set_option("mode.chained_assignment", None)


# Result of DataProcessing._calc_angular_size_values: both entries are angles
# in degrees (the values are produced with math.degrees).
ASValues = TypedDict("ASValues", {"angularSize": float, "viewingAngle": float})


class DataProcessing:
    """Build, cache and load the derived CSV tables of the study.

    Every public ``get_*`` method reads its table from ``Config.DataOutputPath``
    and first regenerates it from the raw participant folders under
    ``Config.InputPath`` when the CSV is missing or ``reload_data`` is True.
    All methods are static; the class only serves as a namespace.
    """

    # Format of the "time" columns in the raw log files (time of day only).
    _TIME_FORMAT = "%H:%M:%S.%f"

    # region Trial Answers
    @staticmethod
    def get_trials_answers(part: Optional[int] = None, reload_data: bool = False) -> Optional[pd.DataFrame]:
        """Return the final answer of every valid trial.

        Args:
            part: ``None`` for all parts combined, or 1, 2 or 3 for one part.
            reload_data: Regenerate the CSV files even if they already exist.

        Returns:
            The trial-answer table, or ``None`` if ``part`` is not None/1/2/3.
        """
        complete_path = f"{Config.DataOutputPath}/trial-answer-data.csv"
        if part is None:
            file_path = complete_path
        elif part in [1, 2, 3]:
            file_path = f"{Config.DataOutputPath}/trial_answer_data-p{part}.csv"
        else:
            file_path = None

        # Regenerate when asked to, or when the combined file or the requested
        # per-part file is missing (_load_trial_answer_data writes all of them).
        if (
            reload_data
            or not os.path.isfile(complete_path)
            or (file_path is not None and not os.path.isfile(file_path))
        ):
            DataProcessing._load_trial_answer_data()

        return None if file_path is None else pd.read_csv(file_path)

    @staticmethod
    def _load_trial_answer_data():
        """Rebuild the trial-answer CSVs (one per part plus a combined file).

        For every participant folder the final answer of each trial is read
        from ``trialLog_<n>.csv`` and matched row by row with the trial-start
        log. Trials whose answer equals the start configuration are dropped.
        Writes ``trial_answer_data-p{1,2,3}.csv`` and ``trial-answer-data.csv``
        to ``Config.DataOutputPath``.
        """
        q_df = DataProcessing.get_questionnaire_data()

        output_columns = (
            "participant", "part", "area", "content", "trialNumber", "duration",
            "contentIndex", "fixedValueIndex", "distance", "tilt", "size",
            "angularSize", "viewingAngle",
        )
        output_data = {p: {column: [] for column in output_columns} for p in range(1, 4)}

        for folder, participant_number in DataProcessing._iter_participant_folders():
            trialLog_pd = pd.read_csv(
                DataProcessing._participant_file(folder, "trialLog_", participant_number),
                sep=";",
                skipinitialspace=True,
            )
            # NOTE(review): floor trials (area != 0) have their tilt mirrored to
            # 360 - tilt, and 360 is folded back to 0. Presumably the trial log
            # stores floor tilt with the opposite direction to the trial-start
            # log it is compared against below -- confirm against the logger.
            tilt = np.where(trialLog_pd["area"] == 0, trialLog_pd["tilt"], 360 - trialLog_pd["tilt"])
            trialLog_pd["tilt"] = np.where(tilt == 360, 0, tilt)
            trialLog_pd = trialLog_pd.drop(["day", "trial"], axis=1)

            # The trial-start log has one row per trial, in the same order as
            # the trial log.
            trialStart_pd = DataProcessing._read_trial_starts(folder, participant_number)
            trialLog_pd["startTime"] = trialStart_pd["time"]
            trialLog_pd["contentIndex"] = trialStart_pd["contentIndex"]
            # Positional row number, kept so invalid trials can still be
            # identified after the frame is split into condition groups.
            trialLog_pd["validIndex"] = list(range(trialLog_pd.shape[0]))

            # A trial is invalid when its answer equals its start
            # configuration, i.e. the participant did not change any input.
            indices_to_remove = [
                t
                for t in range(trialLog_pd.shape[0])
                if not DataProcessing._check_for_valid_trials(
                    row_answer=trialLog_pd.iloc[[t]],
                    row_start=trialStart_pd.iloc[[t]],
                )
            ]

            participant_id = int(participant_number)
            participant_height = DataProcessing._participant_height(q_df, participant_number)

            # Split by part > area > content type. Trial numbers count within
            # each group (from 0 for area 0, from 10 for area 1) and are
            # assigned before invalid trials are dropped.
            for p in range(1, 4):
                for a in range(2):
                    for ct in range(2):
                        group = trialLog_pd[
                            (trialLog_pd["part"] == p)
                            & (trialLog_pd["area"] == a)
                            & (trialLog_pd["content"] == ct)
                        ]
                        offset = 0 if a == 0 else 10
                        group = group.assign(tNumber=list(range(offset, offset + group.shape[0])))
                        group = group[~group["validIndex"].isin(indices_to_remove)]
                        DataProcessing._append_trial_group(
                            output_data[p], group, p, a, ct, participant_id, participant_height
                        )

        frames = []
        for p, columns in output_data.items():
            df = pd.DataFrame.from_dict(columns, orient="columns")
            df.insert(4, "conditionKey", df["area"] + "-" + df["content"])
            df.to_csv(f"{Config.DataOutputPath}/trial_answer_data-p{p}.csv")
            frames.append(df)

        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        pd.concat(frames, ignore_index=True).to_csv(f"{Config.DataOutputPath}/trial-answer-data.csv")

    @staticmethod
    def _append_trial_group(
        target: dict,
        group: pd.DataFrame,
        part: int,
        area: int,
        content: int,
        participant_id: int,
        participant_height: float,
    ) -> None:
        """Append one (part, area, content) group of valid trials to ``target``.

        ``target`` maps each output column name to a list; every list grows by
        ``len(group)`` entries. An empty group appends nothing.
        """
        n_rows = group.shape[0]
        target["participant"].extend([participant_id] * n_rows)
        target["part"].extend([part] * n_rows)
        target["area"].extend(["Ceiling" if area == 0 else "Floor"] * n_rows)
        target["content"].extend(["Low" if content == 0 else "Medium"] * n_rows)
        target["trialNumber"].extend(group["tNumber"])
        target["duration"].extend(DataProcessing._durations(group["startTime"], group["time"]))
        target["contentIndex"].extend(group["contentIndex"])
        target["distance"].extend(group["distance"])
        target["tilt"].extend(group["tilt"])
        target["size"].extend(group["size"])

        # Row-wise helpers via comprehensions: unlike DataFrame.apply(axis=1),
        # these also yield an (empty) list when the group is empty.
        rows = [row for _, row in group.iterrows()]
        target["fixedValueIndex"].extend(DataProcessing._calc_fixed_value_index(row, part) for row in rows)
        as_values = [DataProcessing._calc_angular_size_values(row, participant_height) for row in rows]
        target["angularSize"].extend(v["angularSize"] for v in as_values)
        target["viewingAngle"].extend(v["viewingAngle"] for v in as_values)

    @staticmethod
    def _check_for_valid_trials(row_answer, row_start) -> bool:
        """Return True if the answer differs from the trial's start values.

        Both arguments are single-row DataFrames. A trial is valid when at
        least one of ``distance``, ``tilt`` or ``size`` differs (relative
        tolerance 1e-5).
        """
        return not all(
            math.isclose(row_start[column].values[0], row_answer[column].values[0], rel_tol=1e-5)
            for column in ("distance", "tilt", "size")
        )
    # endregion

    # region Questionnaire Data
    @staticmethod
    def get_questionnaire_data(reload_data: bool = False) -> Optional[pd.DataFrame]:
        """Return the questionnaire table, regenerating its CSV if needed."""
        file_path = f"{Config.DataOutputPath}/questionnaire-data.csv"

        if reload_data or not os.path.isfile(file_path):
            DataProcessing._load_questionnaire_data()

        return pd.read_csv(file_path)

    @staticmethod
    def _load_questionnaire_data():
        """Rebuild ``questionnaire-data.csv`` from ``_Other/all numeric.csv``.

        Adds ``HealthConditionDiff[SQ00i]`` = after - prior for i in 1..6 and
        renames ``ParticipationNumber`` to ``participant``.
        """
        questionnaire_df = pd.read_csv(f"{Config.InputPath}/_Other/all numeric.csv", sep=";")

        for i in range(1, 7):
            after_column = f"HealthConditionAfter[SQ00{i}]"
            prior_column = f"HealthConditionPrior[SQ00{i}]"
            questionnaire_df[f"HealthConditionDiff[SQ00{i}]"] = questionnaire_df.apply(
                lambda row: row[after_column] - row[prior_column],
                axis=1
            )
        questionnaire_df = questionnaire_df.rename(columns={"ParticipationNumber": "participant"})

        questionnaire_df.to_csv(f"{Config.DataOutputPath}/questionnaire-data.csv")
    # endregion

    # region Participant Input Data
    @staticmethod
    def get_participant_input(participant: Optional[int], reload_data: bool = False) -> Optional[pd.DataFrame]:
        """Return one participant's input trace, or ``None`` if ``participant`` is None.

        Raises FileNotFoundError (from ``pd.read_csv``) if no folder for the
        participant exists and the CSV was never generated.
        """
        if participant is None:
            return None

        file_path = f"{Config.DataOutputPath}/participant input {participant}.csv"
        if reload_data or not os.path.isfile(file_path):
            DataProcessing._load_participant_input_data(participant)

        return pd.read_csv(file_path)

    @staticmethod
    def _load_participant_input_data(participant: Optional[int] = None):
        """Rebuild the per-participant input-trace CSVs.

        The input log is replayed row by row. Each trial contributes one row
        with its start configuration (``timeDelta`` 0), followed by one row per
        logged input with ``timeDelta`` in seconds since the trial start. A log
        row whose raw ``data`` contains ``"info"`` ends the current trial.

        Args:
            participant: Only process this participant; ``None`` processes all,
                writing one ``participant input <n>.csv`` per participant.
        """
        q_df = DataProcessing.get_questionnaire_data()
        time_format = DataProcessing._TIME_FORMAT
        output_columns = (
            "participant", "part", "area", "content", "conditionKey",
            "timeDelta", "distance", "tilt", "size",
        )

        for folder, participant_number in DataProcessing._iter_participant_folders():
            participant_id = int(participant_number)
            if participant is not None and participant_id != participant:
                continue

            trialStart_pd = DataProcessing._read_trial_starts(folder, participant_number)
            input_pd = pd.read_csv(
                DataProcessing._participant_file(folder, "Logging_Input-P-", participant_number),
                sep=";",
                skipinitialspace=True,
            )
            participant_height = DataProcessing._participant_height(q_df, participant_number)

            output_data = {column: [] for column in output_columns}
            current_trial = {}
            trial_index = 0
            new_trial = True
            for _, row in input_pd.iterrows():
                if new_trial:
                    # Column-first access keeps each column's own dtype.
                    start = {
                        column: trialStart_pd[column].iloc[trial_index]
                        for column in ("part", "area", "content", "time", "distance", "tilt", "size")
                    }
                    current_trial = {
                        "part": start["part"],
                        "area": "Ceiling" if start["area"] == 0 else "Floor",
                        "content": "Low" if start["content"] == 0 else "Medium",
                        "startTime": datetime.strptime(start["time"], time_format),
                    }
                    DataProcessing._append_input_sample(
                        output_data, participant_id, current_trial, 0,
                        start["distance"], start["tilt"], start["size"],
                    )
                    new_trial = False

                if "info" in row["data"]:
                    trial_index += 1
                    new_trial = True
                    continue

                time_delta = (datetime.strptime(row["time"], time_format) - current_trial["startTime"]).total_seconds()
                values = json.loads(row["data"])
                DataProcessing._append_input_sample(
                    output_data, participant_id, current_trial, time_delta,
                    values["distance"], values["tilt"], values["size"],
                )

            df = pd.DataFrame.from_dict(output_data, orient="columns")

            # Comprehensions instead of DataFrame.apply so an empty trace does
            # not break the column insert.
            rows = [row for _, row in df.iterrows()]
            df.insert(5, "fixedValueIndex", [DataProcessing._calc_fixed_value_index(r, part=None) for r in rows])

            as_values = [DataProcessing._calc_angular_size_values(r, participant_height) for r in rows]
            df["angularSize"] = [v["angularSize"] for v in as_values]
            df["viewingAngle"] = [v["viewingAngle"] for v in as_values]

            # With participant=None every folder is processed, so name each
            # file after its own participant instead of "None".
            output_id = participant if participant is not None else participant_id
            df.to_csv(f"{Config.DataOutputPath}/participant input {output_id}.csv")

    @staticmethod
    def _append_input_sample(target: dict, participant_id: int, trial: dict, time_delta: float, dist, tilt, size) -> None:
        """Append one row of a participant's input trace to ``target``."""
        target["participant"].append(participant_id)
        target["part"].append(trial["part"])
        target["area"].append(trial["area"])
        target["content"].append(trial["content"])
        target["conditionKey"].append(f"{trial['area']}-{trial['content']}")
        target["timeDelta"].append(time_delta)
        target["distance"].append(dist)
        target["tilt"].append(tilt)
        target["size"].append(size)
    # endregion

    # region Utils
    @staticmethod
    def _iter_participant_folders():
        """Yield ``(folder, participant_number)`` for each participant folder.

        Skips entries listed in ``Config.IgnoreFolder``, names starting with
        ``_`` and non-directories. ``participant_number`` is the folder name
        without its first character (presumably a prefix letter -- confirm).
        """
        with os.scandir(Config.InputPath) as entries:
            for folder in entries:
                if folder.name in Config.IgnoreFolder or folder.name.startswith('_'):
                    continue
                if not folder.is_dir():
                    continue
                yield folder, folder.name[1:]

    @staticmethod
    def _participant_file(folder: os.DirEntry, prefix: str, participant_number: str) -> str:
        """Return ``<folder>/<prefix><n>.csv`` with leading zeros stripped from ``n``."""
        return os.path.join(folder.path, f"{prefix}{participant_number.lstrip('0')}.csv")

    @staticmethod
    def _read_trial_starts(folder: os.DirEntry, participant_number: str) -> pd.DataFrame:
        """Read a participant's ``Logging_TrialStarts-<n>.csv``.

        Each row's ``data`` column holds a JSON object whose keys become the
        columns of the returned frame; the raw ``time`` column is attached as
        ``time``. One output row per log row, same index.
        """
        raw = pd.read_csv(
            DataProcessing._participant_file(folder, "Logging_TrialStarts-", participant_number),
            sep=";",
            skipinitialspace=True,
        )
        decoded = raw["data"].map(json.loads)
        trial_starts = pd.DataFrame.from_dict(decoded.to_dict(), orient="index")
        trial_starts["time"] = raw["time"]
        return trial_starts

    @staticmethod
    def _participant_height(q_df: pd.DataFrame, participant_number: str):
        """Return the questionnaire ``Height`` of a participant.

        Raises IndexError if the participant is missing from ``q_df``.
        """
        return q_df[q_df["participant"] == int(participant_number.lstrip("0"))]["Height"].values[0]

    @staticmethod
    def _durations(start_times, end_times) -> list:
        """Return ``end - start`` in seconds for paired time-of-day strings.

        Accepts any iterables (including empty ones) of strings in
        ``_TIME_FORMAT``.
        """
        time_format = DataProcessing._TIME_FORMAT
        return [
            (datetime.strptime(end, time_format) - datetime.strptime(start, time_format)).total_seconds()
            for start, end in zip(start_times, end_times)
        ]

    @staticmethod
    def _calc_fixed_value_index(row, part: Optional[int] = None) -> Optional[int]:
        """Return the index of the row's fixed value within its part's list.

        ``DataConfig.PartToFixed[part]`` names the row column to inspect and
        ``DataConfig.FixedValues[...]`` lists the candidate values for it.
        Uses ``row["part"]`` when ``part`` is None. Returns None if no
        candidate matches (relative tolerance 1e-5).
        """
        if part is None:
            part = row["part"]

        fixed_column = DataConfig.PartToFixed[part]
        for i, fixed_value in enumerate(DataConfig.FixedValues[fixed_column]):
            if math.isclose(row[fixed_column], fixed_value, rel_tol=1e-05):
                return i

        return None

    @staticmethod
    def _calc_angular_size_values(row, height: float) -> Optional["ASValues"]:
        """Compute the content's angular size and viewing angle in degrees.

        2-D side view: the eye is at ``(0, height - Config.EyeOffset)``; the
        content starts at ``(distance, Config.CeilingHeight)`` for the ceiling
        or ``(distance, 0)`` for the floor and extends ``size`` along a vector
        rotated by ``tilt`` (clockwise on the ceiling, counter-clockwise on the
        floor). ``row["area"]`` may be 0/1 or "Ceiling"/"Floor"; any other
        value returns None.

        See https://stackoverflow.com/questions/18583214/calculate-angle-of-triangle-python
        """
        # Heights above 3 are presumably in centimetres; convert to metres.
        if height > 3:
            height = height / 100.

        if row["area"] == 0 or row["area"] == "Ceiling":
            area_pos = np.array([row["distance"], Config.CeilingHeight])
            clockwise = True
        elif row["area"] == 1 or row["area"] == "Floor":
            area_pos = np.array([row["distance"], 0])
            clockwise = False
        else:
            return None

        rotated_size_vector = DataProcessing._rotate_vector(
            vector=np.array([row["size"], 0]),
            clockwise=clockwise,
            angle=row["tilt"]
        )
        eye_pos = np.array([0, height - Config.EyeOffset])

        # Angular size: angle at the eye between both ends of the content
        # (law of cosines; the rotated vector keeps length ``size``).
        content_end_pos = area_pos + rotated_size_vector
        a = distance.euclidean(area_pos, eye_pos)
        b = distance.euclidean(content_end_pos, eye_pos)
        s = row["size"]
        angular_size = DataProcessing._acos_degrees((a * a + b * b - s * s) / (2. * a * b))

        # Viewing angle: angle at the content's midpoint between the segment
        # back to the content start and the line of sight to the eye.
        content_mid_pos = area_pos + rotated_size_vector / 2.
        b = distance.euclidean(content_mid_pos, eye_pos)
        s = row["size"] / 2.
        viewing_angle = DataProcessing._acos_degrees((s * s + b * b - a * a) / (2. * s * b))

        return {
            "angularSize": angular_size,
            "viewingAngle": viewing_angle
        }

    @staticmethod
    def _acos_degrees(cos_value: float) -> float:
        """Return ``acos(cos_value)`` in degrees, clamping the argument to [-1, 1].

        Law-of-cosines ratios can drift slightly outside [-1, 1] through
        floating-point rounding (e.g. nearly collinear points), which would
        make ``math.acos`` raise a domain error. NaN is passed through.
        """
        return math.degrees(math.acos(np.clip(cos_value, -1.0, 1.0)))

    @staticmethod
    def _rotate_vector(vector: np.ndarray, clockwise: bool, angle: float, origin=(0, 0)) -> np.ndarray:
        """Rotate a 2-D ``vector`` by ``angle`` degrees around ``origin``.

        ``clockwise=True`` rotates clockwise (y axis pointing up); otherwise
        the rotation is counter-clockwise, implemented as a clockwise
        rotation by ``360 - angle``.
        """
        x, y = vector
        offset_x, offset_y = origin
        adjusted_x = (x - offset_x)
        adjusted_y = (y - offset_y)

        if clockwise:
            radians = math.radians(angle)
        else:
            radians = math.radians(360 - angle)
        cos_rad = math.cos(radians)
        sin_rad = math.sin(radians)

        return np.array([
            offset_x + cos_rad * adjusted_x + sin_rad * adjusted_y,
            offset_y + -sin_rad * adjusted_x + cos_rad * adjusted_y,
        ])
    # endregion


if __name__ == '__main__':
    # Regenerate every cached CSV: the questionnaire table (the other loaders
    # read participant heights from it), the trial answers, and then the raw
    # input trace of each participant.
    DataProcessing.get_questionnaire_data(reload_data=True)
    DataProcessing.get_trials_answers(reload_data=True)

    for participant_id in range(1, 27):
        print(f"participant {participant_id}")
        DataProcessing.get_participant_input(participant=participant_id, reload_data=True)
